{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import DecisionTreeClassifier\n",
    "from pyspark.ml.classification import RandomForestClassifier\n",
    "from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n",
    "from pyspark.ml.feature import VectorAssembler\n",
    "from pyspark.ml.pipeline import Pipeline, PipelineModel\n",
    "from pyspark.sql import SparkSession"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Obtain the Spark session explicitly rather than relying on a `spark` global\n",
    "# injected by the kernel; getOrCreate() hands back the running session if one exists.\n",
    "spark = SparkSession.builder.getOrCreate()\n",
    "\n",
    "# Load the CSV, treating the first line as the header and letting Spark infer column types.\n",
    "df = spark.read.csv(\"/data/creditcard-fraud.csv\", header=True, inferSchema=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Time</th>\n",
       "      <th>V1</th>\n",
       "      <th>V2</th>\n",
       "      <th>V3</th>\n",
       "      <th>V4</th>\n",
       "      <th>V5</th>\n",
       "      <th>V6</th>\n",
       "      <th>V7</th>\n",
       "      <th>V8</th>\n",
       "      <th>V9</th>\n",
       "      <th>...</th>\n",
       "      <th>V21</th>\n",
       "      <th>V22</th>\n",
       "      <th>V23</th>\n",
       "      <th>V24</th>\n",
       "      <th>V25</th>\n",
       "      <th>V26</th>\n",
       "      <th>V27</th>\n",
       "      <th>V28</th>\n",
       "      <th>Amount</th>\n",
       "      <th>Class</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>-1.359807</td>\n",
       "      <td>-0.072781</td>\n",
       "      <td>2.536347</td>\n",
       "      <td>1.378155</td>\n",
       "      <td>-0.338321</td>\n",
       "      <td>0.462388</td>\n",
       "      <td>0.239599</td>\n",
       "      <td>0.098698</td>\n",
       "      <td>0.363787</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.018307</td>\n",
       "      <td>0.277838</td>\n",
       "      <td>-0.110474</td>\n",
       "      <td>0.066928</td>\n",
       "      <td>0.128539</td>\n",
       "      <td>-0.189115</td>\n",
       "      <td>0.133558</td>\n",
       "      <td>-0.021053</td>\n",
       "      <td>149.62</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>1.191857</td>\n",
       "      <td>0.266151</td>\n",
       "      <td>0.166480</td>\n",
       "      <td>0.448154</td>\n",
       "      <td>0.060018</td>\n",
       "      <td>-0.082361</td>\n",
       "      <td>-0.078803</td>\n",
       "      <td>0.085102</td>\n",
       "      <td>-0.255425</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.225775</td>\n",
       "      <td>-0.638672</td>\n",
       "      <td>0.101288</td>\n",
       "      <td>-0.339846</td>\n",
       "      <td>0.167170</td>\n",
       "      <td>0.125895</td>\n",
       "      <td>-0.008983</td>\n",
       "      <td>0.014724</td>\n",
       "      <td>2.69</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>-1.358354</td>\n",
       "      <td>-1.340163</td>\n",
       "      <td>1.773209</td>\n",
       "      <td>0.379780</td>\n",
       "      <td>-0.503198</td>\n",
       "      <td>1.800499</td>\n",
       "      <td>0.791461</td>\n",
       "      <td>0.247676</td>\n",
       "      <td>-1.514654</td>\n",
       "      <td>...</td>\n",
       "      <td>0.247998</td>\n",
       "      <td>0.771679</td>\n",
       "      <td>0.909412</td>\n",
       "      <td>-0.689281</td>\n",
       "      <td>-0.327642</td>\n",
       "      <td>-0.139097</td>\n",
       "      <td>-0.055353</td>\n",
       "      <td>-0.059752</td>\n",
       "      <td>378.66</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>-0.966272</td>\n",
       "      <td>-0.185226</td>\n",
       "      <td>1.792993</td>\n",
       "      <td>-0.863291</td>\n",
       "      <td>-0.010309</td>\n",
       "      <td>1.247203</td>\n",
       "      <td>0.237609</td>\n",
       "      <td>0.377436</td>\n",
       "      <td>-1.387024</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.108300</td>\n",
       "      <td>0.005274</td>\n",
       "      <td>-0.190321</td>\n",
       "      <td>-1.175575</td>\n",
       "      <td>0.647376</td>\n",
       "      <td>-0.221929</td>\n",
       "      <td>0.062723</td>\n",
       "      <td>0.061458</td>\n",
       "      <td>123.50</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2</td>\n",
       "      <td>-1.158233</td>\n",
       "      <td>0.877737</td>\n",
       "      <td>1.548718</td>\n",
       "      <td>0.403034</td>\n",
       "      <td>-0.407193</td>\n",
       "      <td>0.095921</td>\n",
       "      <td>0.592941</td>\n",
       "      <td>-0.270533</td>\n",
       "      <td>0.817739</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.009431</td>\n",
       "      <td>0.798278</td>\n",
       "      <td>-0.137458</td>\n",
       "      <td>0.141267</td>\n",
       "      <td>-0.206010</td>\n",
       "      <td>0.502292</td>\n",
       "      <td>0.219422</td>\n",
       "      <td>0.215153</td>\n",
       "      <td>69.99</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>2</td>\n",
       "      <td>-0.425966</td>\n",
       "      <td>0.960523</td>\n",
       "      <td>1.141109</td>\n",
       "      <td>-0.168252</td>\n",
       "      <td>0.420987</td>\n",
       "      <td>-0.029728</td>\n",
       "      <td>0.476201</td>\n",
       "      <td>0.260314</td>\n",
       "      <td>-0.568671</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.208254</td>\n",
       "      <td>-0.559825</td>\n",
       "      <td>-0.026398</td>\n",
       "      <td>-0.371427</td>\n",
       "      <td>-0.232794</td>\n",
       "      <td>0.105915</td>\n",
       "      <td>0.253844</td>\n",
       "      <td>0.081080</td>\n",
       "      <td>3.67</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>4</td>\n",
       "      <td>1.229658</td>\n",
       "      <td>0.141004</td>\n",
       "      <td>0.045371</td>\n",
       "      <td>1.202613</td>\n",
       "      <td>0.191881</td>\n",
       "      <td>0.272708</td>\n",
       "      <td>-0.005159</td>\n",
       "      <td>0.081213</td>\n",
       "      <td>0.464960</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.167716</td>\n",
       "      <td>-0.270710</td>\n",
       "      <td>-0.154104</td>\n",
       "      <td>-0.780055</td>\n",
       "      <td>0.750137</td>\n",
       "      <td>-0.257237</td>\n",
       "      <td>0.034507</td>\n",
       "      <td>0.005168</td>\n",
       "      <td>4.99</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>7</td>\n",
       "      <td>-0.644269</td>\n",
       "      <td>1.417964</td>\n",
       "      <td>1.074380</td>\n",
       "      <td>-0.492199</td>\n",
       "      <td>0.948934</td>\n",
       "      <td>0.428118</td>\n",
       "      <td>1.120631</td>\n",
       "      <td>-3.807864</td>\n",
       "      <td>0.615375</td>\n",
       "      <td>...</td>\n",
       "      <td>1.943465</td>\n",
       "      <td>-1.015455</td>\n",
       "      <td>0.057504</td>\n",
       "      <td>-0.649709</td>\n",
       "      <td>-0.415267</td>\n",
       "      <td>-0.051634</td>\n",
       "      <td>-1.206921</td>\n",
       "      <td>-1.085339</td>\n",
       "      <td>40.80</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>7</td>\n",
       "      <td>-0.894286</td>\n",
       "      <td>0.286157</td>\n",
       "      <td>-0.113192</td>\n",
       "      <td>-0.271526</td>\n",
       "      <td>2.669599</td>\n",
       "      <td>3.721818</td>\n",
       "      <td>0.370145</td>\n",
       "      <td>0.851084</td>\n",
       "      <td>-0.392048</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.073425</td>\n",
       "      <td>-0.268092</td>\n",
       "      <td>-0.204233</td>\n",
       "      <td>1.011592</td>\n",
       "      <td>0.373205</td>\n",
       "      <td>-0.384157</td>\n",
       "      <td>0.011747</td>\n",
       "      <td>0.142404</td>\n",
       "      <td>93.20</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>9</td>\n",
       "      <td>-0.338262</td>\n",
       "      <td>1.119593</td>\n",
       "      <td>1.044367</td>\n",
       "      <td>-0.222187</td>\n",
       "      <td>0.499361</td>\n",
       "      <td>-0.246761</td>\n",
       "      <td>0.651583</td>\n",
       "      <td>0.069539</td>\n",
       "      <td>-0.736727</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.246914</td>\n",
       "      <td>-0.633753</td>\n",
       "      <td>-0.120794</td>\n",
       "      <td>-0.385050</td>\n",
       "      <td>-0.069733</td>\n",
       "      <td>0.094199</td>\n",
       "      <td>0.246219</td>\n",
       "      <td>0.083076</td>\n",
       "      <td>3.68</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>10 rows × 31 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "  Time        V1        V2        V3        V4        V5        V6        V7  \\\n",
       "0    0 -1.359807 -0.072781  2.536347  1.378155 -0.338321  0.462388  0.239599   \n",
       "1    0  1.191857  0.266151  0.166480  0.448154  0.060018 -0.082361 -0.078803   \n",
       "2    1 -1.358354 -1.340163  1.773209  0.379780 -0.503198  1.800499  0.791461   \n",
       "3    1 -0.966272 -0.185226  1.792993 -0.863291 -0.010309  1.247203  0.237609   \n",
       "4    2 -1.158233  0.877737  1.548718  0.403034 -0.407193  0.095921  0.592941   \n",
       "5    2 -0.425966  0.960523  1.141109 -0.168252  0.420987 -0.029728  0.476201   \n",
       "6    4  1.229658  0.141004  0.045371  1.202613  0.191881  0.272708 -0.005159   \n",
       "7    7 -0.644269  1.417964  1.074380 -0.492199  0.948934  0.428118  1.120631   \n",
       "8    7 -0.894286  0.286157 -0.113192 -0.271526  2.669599  3.721818  0.370145   \n",
       "9    9 -0.338262  1.119593  1.044367 -0.222187  0.499361 -0.246761  0.651583   \n",
       "\n",
       "         V8        V9  ...         V21       V22       V23       V24  \\\n",
       "0  0.098698  0.363787  ...   -0.018307  0.277838 -0.110474  0.066928   \n",
       "1  0.085102 -0.255425  ...   -0.225775 -0.638672  0.101288 -0.339846   \n",
       "2  0.247676 -1.514654  ...    0.247998  0.771679  0.909412 -0.689281   \n",
       "3  0.377436 -1.387024  ...   -0.108300  0.005274 -0.190321 -1.175575   \n",
       "4 -0.270533  0.817739  ...   -0.009431  0.798278 -0.137458  0.141267   \n",
       "5  0.260314 -0.568671  ...   -0.208254 -0.559825 -0.026398 -0.371427   \n",
       "6  0.081213  0.464960  ...   -0.167716 -0.270710 -0.154104 -0.780055   \n",
       "7 -3.807864  0.615375  ...    1.943465 -1.015455  0.057504 -0.649709   \n",
       "8  0.851084 -0.392048  ...   -0.073425 -0.268092 -0.204233  1.011592   \n",
       "9  0.069539 -0.736727  ...   -0.246914 -0.633753 -0.120794 -0.385050   \n",
       "\n",
       "        V25       V26       V27       V28  Amount  Class  \n",
       "0  0.128539 -0.189115  0.133558 -0.021053  149.62      0  \n",
       "1  0.167170  0.125895 -0.008983  0.014724    2.69      0  \n",
       "2 -0.327642 -0.139097 -0.055353 -0.059752  378.66      0  \n",
       "3  0.647376 -0.221929  0.062723  0.061458  123.50      0  \n",
       "4 -0.206010  0.502292  0.219422  0.215153   69.99      0  \n",
       "5 -0.232794  0.105915  0.253844  0.081080    3.67      0  \n",
       "6  0.750137 -0.257237  0.034507  0.005168    4.99      0  \n",
       "7 -0.415267 -0.051634 -1.206921 -1.085339   40.80      0  \n",
       "8  0.373205 -0.384157  0.011747  0.142404   93.20      0  \n",
       "9 -0.069733  0.094199  0.246219  0.083076    3.68      0  \n",
       "\n",
       "[10 rows x 31 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Bring only the first ten rows to the driver and show them as a pandas frame.\n",
    "preview_rows = df.limit(10)\n",
    "preview_rows.toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28']\n"
     ]
    }
   ],
   "source": [
    "# Model inputs are the columns whose name begins with the letter V.\n",
    "feature_columns = [name for name in df.columns if name.startswith(\"V\")]\n",
    "print(feature_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>features</th>\n",
       "      <th>Class</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[-1.3598071336738, -0.0727811733098497, 2.5363...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[1.19185711131486, 0.26615071205963, 0.1664801...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[-1.35835406159823, -1.34016307473609, 1.77320...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[-0.966271711572087, -0.185226008082898, 1.792...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[-1.15823309349523, 0.877736754848451, 1.54871...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                            features  Class\n",
       "0  [-1.3598071336738, -0.0727811733098497, 2.5363...      0\n",
       "1  [1.19185711131486, 0.26615071205963, 0.1664801...      0\n",
       "2  [-1.35835406159823, -1.34016307473609, 1.77320...      0\n",
       "3  [-0.966271711572087, -0.185226008082898, 1.792...      0\n",
       "4  [-1.15823309349523, 0.877736754848451, 1.54871...      0"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pack the selected input columns into a single vector column called `features`.\n",
    "vectorizer = VectorAssembler(outputCol=\"features\", inputCols=feature_columns)\n",
    "\n",
    "# Show the assembled vector next to the label for the first five rows.\n",
    "assembled_preview = vectorizer.transform(df).select(\"features\", \"Class\").limit(5)\n",
    "assembled_preview.toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RandomForestClassifier_4c658cfbc38faf430271"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Random forest configured at construction time: depth limit 5, label read from `Class`.\n",
    "est = RandomForestClassifier(labelCol=\"Class\", maxDepth=5)\n",
    "est"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cacheNodeIds: If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval. (default: False)\n",
      "checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext. (default: 10)\n",
      "featureSubsetStrategy: The number of features to consider for splits at each tree node. Supported options: auto, all, onethird, sqrt, log2, (0.0-1.0], [1-n]. (default: auto)\n",
      "featuresCol: features column name. (default: features)\n",
      "impurity: Criterion used for information gain calculation (case-insensitive). Supported options: entropy, gini (default: gini)\n",
      "labelCol: label column name. (default: label, current: Class)\n",
      "maxBins: Max number of bins for discretizing continuous features.  Must be >=2 and >= number of categories for any categorical feature. (default: 32)\n",
      "maxDepth: Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default: 5, current: 5)\n",
      "maxMemoryInMB: Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size. (default: 256)\n",
      "minInfoGain: Minimum information gain for a split to be considered at a tree node. (default: 0.0)\n",
      "minInstancesPerNode: Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1. (default: 1)\n",
      "numTrees: Number of trees to train (>= 1). (default: 20)\n",
      "predictionCol: prediction column name. (default: prediction)\n",
      "probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. (default: probability)\n",
      "rawPredictionCol: raw prediction (a.k.a. confidence) column name. (default: rawPrediction)\n",
      "seed: random seed. (default: -5387697053847413545)\n",
      "subsamplingRate: Fraction of the training data used for learning each decision tree, in range (0, 1]. (default: 1.0)\n"
     ]
    }
   ],
   "source": [
    "# Print the estimator's parameter documentation, including defaults and current values.\n",
    "param_summary = est.explainParams()\n",
    "print(param_summary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 70/30 train/test split driven by a fixed seed.\n",
    "split_weights = [0.7, 0.3]\n",
    "df_train, df_test = df.randomSplit(split_weights, seed=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Chain feature assembly and the classifier, then fit both stages on the training split.\n",
    "pipeline = Pipeline(stages=[vectorizer, est])\n",
    "model = pipeline.fit(df_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the fitted pipeline to the held-out split; a later cell computes accuracy from this frame.\n",
    "df_test_pred = model.transform(df_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BinaryClassificationEvaluator_42c282003717c18c5038"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Binary evaluator reading the true label from `Class`; every other setting keeps its default.\n",
    "evaluator = BinaryClassificationEvaluator(labelCol=\"Class\")\n",
    "evaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9691279828708574"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the test split with the current fitted pipeline and report the evaluator's metric.\n",
    "scored_test = model.transform(df_test)\n",
    "evaluator.evaluate(scored_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import only the Spark SQL helpers this notebook uses. A wildcard import of\n",
    "# pyspark.sql.functions would shadow Python builtins such as sum, min, max, abs and round.\n",
    "from pyspark.sql.functions import avg, expr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Accuracy by hand: flag rows where the prediction equals the label, then average the flags.\n",
    "label_vs_pred = df_test_pred.select(\"Class\", \"prediction\")\n",
    "flagged = label_vs_pred.withColumn(\"isEqual\", expr(\"Class == prediction\"))\n",
    "mean_match = avg(expr(\"cast(isEqual as float)\"))\n",
    "\n",
    "# first() returns a single Row holding the averaged value.\n",
    "test_accuracy = flagged.select(mean_match).first()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Row(avg(CAST(isEqual AS FLOAT))=0.999285011017863)"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the Row produced by the accuracy aggregation.\n",
    "test_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7787428245059321"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Single decision tree using entropy as the split criterion, label read from `Class`.\n",
    "treeEstimator = DecisionTreeClassifier(labelCol=\"Class\", impurity=\"entropy\")\n",
    "\n",
    "# NOTE: this rebinds `pipeline` and `model`, so any cell executed after this one\n",
    "# that uses `model` scores the decision tree rather than the random forest.\n",
    "pipeline = Pipeline(stages=[vectorizer, treeEstimator])\n",
    "model = pipeline.fit(df_train)\n",
    "\n",
    "tree_scored_test = model.transform(df_test)\n",
    "evaluator.evaluate(tree_scored_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9992732898870083"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accuracy of whichever pipeline `model` currently refers to, measured on the test split.\n",
    "accuracy_evaluator = MulticlassClassificationEvaluator(labelCol=\"Class\", metricName=\"accuracy\")\n",
    "accuracy_scored = model.transform(df_test)\n",
    "accuracy_evaluator.evaluate(accuracy_scored)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9992479231033246"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# F1 score of whichever pipeline `model` currently refers to, measured on the test split.\n",
    "f1_evaluator = MulticlassClassificationEvaluator(labelCol=\"Class\", metricName=\"f1\")\n",
    "f1_scored = model.transform(df_test)\n",
    "f1_evaluator.evaluate(f1_scored)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
